## Generate state-action data from a deterministic evaluation policy
import os
import argparse
import warnings
import random
import numpy as np
import pandas as pd
from tqdm import tqdm

from sklearn.model_selection import train_test_split
from sklearn.linear_model import LogisticRegression
from sklearn.exceptions import ConvergenceWarning

from src.utils import load_dataset

## Silence noisy warning categories for the whole run:
## - ConvergenceWarning (imported from sklearn.exceptions above)
## - FutureWarning
warnings.simplefilter("ignore", category=ConvergenceWarning)
warnings.simplefilter("ignore", category=FutureWarning)

if __name__ == "__main__":
    ## ---- Command-line interface ----
    cli = argparse.ArgumentParser()
    cli.add_argument("--dataset", "-d", type=str, required=True)
    cli.add_argument("--num_exps", "-n", type=int, default=200)
    cli.add_argument("--ope_ratio", "-ope", type=float, default=0.7)
    cli.add_argument("--random_state", "-r", type=int, default=42)
    args = cli.parse_args()
    print(args)

    dataset, num_exps = args.dataset, args.num_exps
    ope_ratio, seed = args.ope_ratio, args.random_state

    ## Seed both the stdlib and the NumPy global RNGs with the same value
    random.seed(seed)
    np.random.seed(seed)

    print(f'[Data: {dataset}] Generating actions from a deterministic evaluation policy...')

    ## All outputs of this script live under action_data/<dataset>/
    out_dir = f'action_data/{dataset}'
    os.makedirs(out_dir, exist_ok=True)
    data_dict = load_dataset(data=dataset)

    ## ---- State data ----
    ## Wrap the features in a DataFrame so that every row keeps an index
    ## (row id) that survives the per-experiment shuffles below.
    num_features = data_dict['X'].shape[-1]
    state_data_df = pd.DataFrame(data_dict['X'])
    state_data_df['y'] = data_dict['y']
    state_data_df.attrs.update(dataset=dataset, num_features=num_features)
    state_data_df.to_parquet(
        f'{out_dir}/state_data.parquet',
        engine='pyarrow',
        compression='snappy',
    )

    ## The first `num_features` columns are the features; 'y' is the label.
    features = state_data_df.iloc[:, :num_features]
    labels = state_data_df['y']

    ## ---- Action data: one record per experiment ----
    ## Fraction of rows handed to the classifier in every experiment.
    ## NOTE(review): this is a floating-point subtraction (e.g. 1 - 0.7 is
    ## 0.30000000000000004) and is passed to train_test_split unrounded.
    eval_gen_ratio = 1 - ope_ratio

    records = []
    for exp_id in tqdm(range(num_exps)):
        ## A distinct but reproducible seed for each experiment
        exp_seed = seed + exp_id

        ## The "train" side of the split is the OPE portion; the "test" side
        ## is what the evaluation-policy classifier is fitted on.
        X_ope, X_eval_gen, y_ope, y_eval_gen = train_test_split(
            features,
            labels,
            test_size=eval_gen_ratio,
            random_state=exp_seed,
            shuffle=True,
        )

        ## Deterministic part of the evaluation policy (pi_det): a logistic
        ## regression whose predictions on the OPE rows serve as actions.
        ## NOTE(review): `multi_class` appears to be deprecated in recent
        ## scikit-learn releases - confirm before upgrading the dependency.
        policy_model = LogisticRegression(
            random_state=exp_seed,
            solver="lbfgs",
            multi_class="multinomial",
        )
        policy_model.fit(X_eval_gen, y_eval_gen)
        predicted = policy_model.predict(X_ope)

        records.append({
            'y_gt': np.asarray(y_ope, dtype=np.int32),
            'y_eval_pol': predicted.astype(int).astype(np.int32),
            ## Index labels of the OPE rows, i.e. their row ids in state_data_df
            'row_indices': np.asarray(y_ope.index, dtype=np.int32),
        })

    ## Each DataFrame row is one experiment; each cell holds an int32 array.
    action_data_df = pd.DataFrame(records)
    action_data_df.attrs.update(dataset=dataset, num_classes=data_dict['n_class'])
    action_data_df.to_parquet(
        f'{out_dir}/action_data.parquet',
        engine='pyarrow',
        compression='snappy',
    )